from yolov5_face_detection import yolov5_face_detector
from yolov3_plate_detection import yolov3_license_plate_detector
from PyQt5.QtWidgets import QApplication, QFileDialog
from PyQt5.QtGui import QPixmap
from PyQt5 import uic, QtCore
import cv2
import numpy as np
import os
from time import time
import json


def cv_img_read(file_path):
    """Read an image file into a colour (BGR) array.

    ``cv2.imread`` fails on paths that contain Chinese characters, so the raw
    file bytes are loaded with ``np.fromfile`` and decoded in memory instead.

    Args:
        file_path: Path of the image file.

    Returns:
        The decoded image, or ``None`` when the bytes cannot be decoded
        (callers in this file test for ``None``).
    """
    raw_bytes = np.fromfile(file_path, dtype=np.uint8)
    decoded = cv2.imdecode(raw_bytes, cv2.IMREAD_COLOR)
    return decoded


def cv_img_write(output_path, image):
    """Encode *image* as JPEG and write the bytes to *output_path*.

    Used instead of ``cv2.imwrite``, which fails on paths containing Chinese
    characters. JPEG is always used regardless of the extension of
    *output_path*, and the encoder's success flag is not inspected.
    """
    _ok, encoded = cv2.imencode('.jpg', image)
    encoded.tofile(output_path)


class Data_Mask:
    """Qt front-end that detects faces / license plates and masks or marks them.

    The widgets are loaded from ``./utils/data_mask_demo_v1.ui``. A single image
    can be processed and previewed (``show_result``) and saved
    (``save_result``), or every file of an input directory can be processed in
    one run (``forward_batch``).
    """

    # Work mode (combo-box text) -> suffix appended to output file names.
    # Insertion order is also the order of the combo-box entries.
    _MODE_SUFFIX = {
        'Data Mask': '-data_mask',
        'Face/License Plate Detect': '-detect_both',
        'Face Mosaic': '-face_mosaic',
        'Face Detect': '-face_detect',
        'License Plate Mosaic': '-plate_mosaic',
        'License Plate Detect': '-plate_detect',
    }
    # Work mode -> ``mark`` argument for the face detector's postprocess().
    # Modes missing here skip face detection.
    _FACE_MARK = {
        'Data Mask': 'mosaic',
        'Face/License Plate Detect': 'origin',
        'Face Mosaic': 'mosaic',
        'Face Detect': 'origin',
    }
    # Work mode -> ``mode`` argument for the plate detector's rectangle().
    # Modes missing here skip plate detection.
    _PLATE_MARK = {
        'Data Mask': 'mosaic',
        'Face/License Plate Detect': 'origin',
        'License Plate Mosaic': 'mosaic',
        'License Plate Detect': 'origin',
    }

    def __init__(self):
        self.ui = uic.loadUi("./utils/data_mask_demo_v1.ui")
        self.ui.workmode_combo.addItems(list(self._MODE_SUFFIX))
        self.image_file = ''
        # Last processed image (an array) or None. A None sentinel is used
        # because comparing an ndarray with '' is element-wise in recent NumPy
        # and cannot be used directly in an `if`.
        self.result_file = None
        self.result_num = 0
        self.mosaic_level = 2
        self.work_mode = 'Data Mask'
        self.mid_name = self._MODE_SUFFIX[self.work_mode]

        self.ui.mosaic_level_slider.setMinimum(1)
        self.ui.mosaic_level_slider.setMaximum(5)
        self.ui.mosaic_level_slider.setSingleStep(1)
        self.ui.mosaic_level_slider.setValue(self.mosaic_level)
        self.ui.mosaic_level_slider.valueChanged.connect(self.mosaic_level_change)
        self.ui.mosaic_level_textbrowser.setText(str(self.mosaic_level))

        self.ui.read_image_button.clicked.connect(self.get_image)
        self.ui.show_result_button.clicked.connect(self.show_result)
        self.ui.save_result_button.clicked.connect(self.save_result)
        self.ui.input_button.clicked.connect(self.get_input_path)
        self.ui.output_button.clicked.connect(self.get_output_path)
        self.ui.start_button.clicked.connect(self.forward_batch)

        # --------------face_detect_config_start----------------------
        # Items are added before the signal is connected so that filling the
        # combo does not trigger a detector rebuild.
        self.ui.yolomode_combo.addItems(['yolov5s', 'yolov5m', 'yolov5l'])
        self.ui.workmode_combo.currentIndexChanged.connect(self.work_mode_change)
        self.ui.yolomode_combo.currentIndexChanged.connect(self.yolo_mode_change)

        self.yolo_mode_face = 'yolov5s'
        self.face_yolo_conf_thresh = 0.5
        self.face_yolo_obj_thresh = 0.5
        self.face_yolo_nms_iou = 0.5
        # Initialise the spin boxes *before* connecting valueChanged, then build
        # the detector once from the displayed values. Previously each
        # setValue() could rebuild the model.
        self.ui.class_conf_dsbox.setValue(self.face_yolo_conf_thresh)
        self.ui.obj_conf_dsbox.setValue(self.face_yolo_obj_thresh)
        self.ui.nms_iou_dsbox.setValue(self.face_yolo_nms_iou)
        self.yolo_params_change()
        self.ui.class_conf_dsbox.valueChanged.connect(self.yolo_params_change)
        self.ui.obj_conf_dsbox.valueChanged.connect(self.yolo_params_change)
        self.ui.nms_iou_dsbox.valueChanged.connect(self.yolo_params_change)
        # --------------face_detect_config_end----------------------

        # --------------plate_detect_config_start----------------------
        self.plate_yolo_conf_thresh = 0.5
        self.plate_yolo_nms_iou = 0.4
        self.ui.plate_class_dsbox.setValue(self.plate_yolo_conf_thresh)
        self.ui.plate_nms_dsbox.setValue(self.plate_yolo_nms_iou)
        # Build the plate detector with the thresholds shown in the UI.
        self.plate_yolo_params_change()
        self.ui.plate_class_dsbox.valueChanged.connect(self.plate_yolo_params_change)
        self.ui.plate_nms_dsbox.valueChanged.connect(self.plate_yolo_params_change)
        # --------------plate_detect_config_end----------------------

    # ------------------------------------------------------------------ helpers

    def _build_face_detector(self):
        """Create a face detector from the current face YOLO settings."""
        return yolov5_face_detector(yolo_type=self.yolo_mode_face,
                                    confThreshold=self.face_yolo_conf_thresh,
                                    nmsThreshold=self.face_yolo_nms_iou,
                                    objThreshold=self.face_yolo_obj_thresh)

    def _mosaic_strength(self):
        """Map the 1..5 slider level to the 0.5..0.9 value passed as ``mosaic_level``."""
        return (self.mosaic_level + 4) * 0.1

    @staticmethod
    def _result_name(image_path, suffix):
        """Return ``'<stem><suffix>.jpg'`` for *image_path*.

        Only the last extension is removed, so ``'a.b.png'`` gives
        ``'a.b<suffix>.jpg'``. A plain ``split('.')[0]`` would give
        ``'a<suffix>.jpg'``.
        """
        stem = os.path.splitext(os.path.basename(image_path))[0]
        return stem + suffix + '.jpg'

    # -------------------------------------------------------------------- slots

    def work_mode_change(self):
        """Slot for the work-mode combo: store the mode and its file-name suffix."""
        self.work_mode = self.ui.workmode_combo.currentText()
        # Unknown text keeps the previous suffix, like the former if/elif chain.
        self.mid_name = self._MODE_SUFFIX.get(self.work_mode, self.mid_name)
        self.printf(f"\nSet work mode '{self.work_mode}'")

    def yolo_mode_change(self):
        """Slot for the YOLO-variant combo: rebuild the face detector."""
        self.yolo_mode_face = self.ui.yolomode_combo.currentText()
        self.printf(f"\nSet yolo mode '{self.yolo_mode_face}'")
        self.face_detector = self._build_face_detector()

    def yolo_params_change(self):
        """Slot for the face threshold spin boxes: re-read them and rebuild the detector."""
        self.face_yolo_conf_thresh = self.ui.class_conf_dsbox.value()
        self.face_yolo_obj_thresh = self.ui.obj_conf_dsbox.value()
        self.face_yolo_nms_iou = self.ui.nms_iou_dsbox.value()
        self.face_detector = self._build_face_detector()

    def plate_yolo_params_change(self):
        """Slot for the plate threshold spin boxes: re-read them and rebuild the detector."""
        self.plate_yolo_conf_thresh = self.ui.plate_class_dsbox.value()
        self.plate_yolo_nms_iou = self.ui.plate_nms_dsbox.value()
        self.plate_detector = yolov3_license_plate_detector(confThreshold=self.plate_yolo_conf_thresh,
                                                            nmsThreshold=self.plate_yolo_nms_iou)

    def mosaic_level_change(self):
        """Slot for the mosaic slider: store the level and echo it in the UI."""
        self.mosaic_level = self.ui.mosaic_level_slider.value()
        self.ui.mosaic_level_textbrowser.setText(str(self.mosaic_level))

    def work_process(self, srcimg):
        """Run the detectors selected by ``self.work_mode`` on *srcimg*.

        The face step (if any) runs on *srcimg*. The plate step (if any) runs on
        the face step's output, or on *srcimg* when no face step ran. The
        'mosaic'/'origin' marks are forwarded unchanged to the detectors. Judging
        by the mode names, they mosaic or merely mark the detected regions
        (confirm in the detector modules).

        Returns:
            ``(result_image, face_num, plate_num, face_boxes, plate_boxes)``.
            For a detector that did not run, the count is 0 and the boxes are
            None. The result image is also stored in ``self.result_file``.

        Raises:
            ValueError: if ``self.work_mode`` is not one of the known modes.
        """
        face_mark = self._FACE_MARK.get(self.work_mode)
        plate_mark = self._PLATE_MARK.get(self.work_mode)
        if face_mark is None and plate_mark is None:
            raise ValueError(f"Unknown work mode: {self.work_mode!r}")

        face_num = 0
        plate_num = 0
        face_boxes = None
        plate_boxes = None
        result_img = srcimg
        strength = self._mosaic_strength()

        if face_mark is not None:
            face_dets = self.face_detector.detect(srcimg)
            if face_mark == 'mosaic':
                result_img, face_num, face_boxes = self.face_detector.postprocess(
                    srcimg, face_dets, mark='mosaic', mosaic_level=strength)
            else:
                result_img, face_num, face_boxes = self.face_detector.postprocess(
                    srcimg, face_dets, mark='origin')

        if plate_mark is not None:
            if plate_mark == 'mosaic':
                result_img, plate_num, plate_boxes = self.plate_detector.rectangle(
                    result_img, mode='mosaic', mosaic_level=strength)
            else:
                result_img, plate_num, plate_boxes = self.plate_detector.rectangle(
                    result_img, mode='origin')

        self.result_file = result_img
        return result_img, face_num, plate_num, face_boxes, plate_boxes

    def get_image(self):
        """Ask the user for an image file and preview it in the left label."""
        # Signature is getOpenFileName(parent, caption, directory, filter). The
        # title belongs in the caption slot; it was previously passed as the
        # start directory.
        image_file, _ = QFileDialog.getOpenFileName(
            self.ui,  # parent window
            "选择图片",
            "",
            "(*.png *.jpg *.jpeg);;All Files(*)"  # file-type filter
        )
        if not image_file:
            # Dialog cancelled: keep the previous selection.
            return
        self.image_file = image_file
        # A result computed for another image must not be saved under this name.
        self.result_file = None
        img = QPixmap(self.image_file).scaled(self.ui.origin_image_label.width(),
                                              self.ui.origin_image_label.height())
        self.ui.origin_image_label.setPixmap(img)
        self.ui.origin_image_label.setScaledContents(True)
        self.printf(f"\nGet image {self.image_file}")

    def show_result(self):
        """Process the selected image and show the result in the right label."""
        if self.image_file == '':
            self.printf("\nWARNING: Please get image file first!")
            return
        srcimg = cv_img_read(self.image_file)
        if srcimg is None:
            self.printf("\nPlease select the image file!")
            return

        result_img, face_num, plate_num, _, _ = self.work_process(srcimg=srcimg)
        # QPixmap loads the result from a temporary JPEG in the cwd.
        # cv_img_write is used because cv2.imwrite fails on non-ASCII
        # (e.g. Chinese) paths.
        result_image_file = os.path.join(os.getcwd(), "result.jpg")
        cv_img_write(result_image_file, result_img)
        img = QPixmap(result_image_file).scaled(self.ui.result_image_label.width(),
                                                self.ui.result_image_label.height())
        self.ui.result_image_label.setPixmap(img)
        self.ui.result_image_label.setScaledContents(True)

        self.printf(f"\n{face_num} faces detected in the image")
        self.printf(f"\n{plate_num} license plates detected in the image")

    def save_result(self):
        """Save the last result as ``<image stem><mode suffix>.jpg`` in a chosen folder."""
        if self.result_file is None:
            self.printf("\nWARNING: Please get result of image file first!")
            return
        get_directory_path = QFileDialog.getExistingDirectory(self.ui, "选取指定文件夹")
        if not get_directory_path:
            # Dialog cancelled; previously this wrote silently into the cwd.
            return
        result_save_file_name = self._result_name(self.image_file, self.mid_name)
        result_save_image_file = os.path.join(get_directory_path, result_save_file_name)
        cv_img_write(result_save_image_file, self.result_file)
        self.printf(f"\nSave current result file to {get_directory_path}")

    def get_input_path(self):
        """Ask for the batch input directory; a cancelled dialog keeps the old one."""
        get_directory_path = QFileDialog.getExistingDirectory(self.ui, "选取指定文件夹")
        if not get_directory_path:
            return
        self.ui.input_dir.setText(str(get_directory_path))
        self.printf(f"\nSet images input path: '{get_directory_path}'")

    def get_output_path(self):
        """Ask for the batch output directory; a cancelled dialog keeps the old one."""
        get_directory_path = QFileDialog.getExistingDirectory(self.ui, "选取指定文件夹")
        if not get_directory_path:
            return
        self.ui.output_dir.setText(str(get_directory_path))
        self.printf(f"\nSet images output path: '{get_directory_path}'")

    def forward_batch(self):
        """Process every file directly inside the input directory.

        Only images with at least one detection are written to the output
        directory. The detected boxes per written image are dumped to
        ``result.json`` in the output directory, and a summary is printed.
        """
        input_path = self.ui.input_dir.text()
        output_path = self.ui.output_dir.text()
        if input_path == '' or output_path == '':
            self.printf("\nWARNING: Please select both input path and output path!")
            return
        if not (os.path.isdir(input_path) and os.path.isdir(output_path)):
            self.printf("\nWARNING: Input path and output path must be existing directories!")
            return

        self.printf("\nBatch Process Start...")
        images_list = os.listdir(input_path)
        files_len = len(images_list)
        self.ui.progressBar.setMaximum(files_len)
        start_time = time()
        not_image_num = 0
        dir_num = 0
        process_face_image_num = 0
        process_plate_image_num = 0
        process_image_num = 0
        out_dict = {}
        for idx, img_name in enumerate(images_list):
            self.ui.progressBar.setValue(idx + 1)
            img_path = os.path.join(input_path, img_name)
            if not os.path.isfile(img_path):
                dir_num += 1
                continue
            srcimg = cv_img_read(img_path)
            if srcimg is None:
                not_image_num += 1
                continue

            result_img, face_num, plate_num, face_boxes, plate_boxes = self.work_process(srcimg=srcimg)
            if face_num < 1 and plate_num < 1:
                # Images without any detection are not written.
                continue
            if face_num >= 1:
                process_face_image_num += 1
            if plate_num >= 1:
                process_plate_image_num += 1
            process_image_num += 1
            img_new_name = self._result_name(img_name, self.mid_name)
            cv_img_write(os.path.join(output_path, img_new_name), result_img)
            out_dict[img_new_name] = f'face:{face_boxes}  plate:{plate_boxes}'

        with open(os.path.join(output_path, 'result.json'), 'w', encoding='utf-8') as jw:
            json.dump(out_dict, jw, indent=1)
        duration = time() - start_time

        self.printf("\nBatch Process Success.")
        self.printf(f"\nImage(with object) num: {process_image_num}")
        self.printf(f"Image(with face) num: {process_face_image_num}")
        self.printf(f"Image(with license plate) num: {process_plate_image_num}")
        self.printf(f"\nImage total num: {files_len - not_image_num - dir_num}")
        self.printf(f"Directory num: {dir_num}")
        self.printf(f"Other file num: {not_image_num}")
        self.printf(f"Total num: {files_len}")
        self.printf(f"\nTime spend: {duration:.3f}s")

    def printf(self, mes):
        """Append *mes* to the message window, scroll to the end and repaint."""
        browser = self.ui.message_window_textbrowser
        browser.append(mes)
        browser.moveCursor(browser.textCursor().End)
        QApplication.processEvents()


if __name__ == "__main__":
    import sys

    # Enable automatic scaling on high-DPI screens; Qt requires this attribute
    # to be set before the application object is created.
    QtCore.QCoreApplication.setAttribute(QtCore.Qt.AA_EnableHighDpiScaling)
    app = QApplication([])
    Dd = Data_Mask()
    Dd.ui.show()
    # Propagate Qt's exit status to the calling shell.
    sys.exit(app.exec())
